import logging

from collections import defaultdict
from lab.reports import Table, CellFormatter
from downward.reports import PlanningReport

import matplotlib as mpl
#mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

def toString(f):
    """Return ``f`` rendered as text with one digit after the decimal point."""
    # format() dispatches to the same __format__ hook as "{0:.1f}".format().
    rendered = format(f, ".1f")
    return str(rendered)

def valid(attribute):
    """Tell whether ``attribute`` holds a usable value.

    A value counts as unusable when it is ``None`` or equal to ``-1``.
    """
    if attribute != -1:
        return attribute is not None
    return False

class FmmPlotReport(PlanningReport):
    """
    This report is exclusively designed to work with the fMM algorithm.
    It will create multiple graphs for each problem instance and a summary graph
    for whole domains.

    >>> from downward.experiment import FastDownwardExperiment
    >>> exp = FastDownwardExperiment()
    >>> exp.add_report(FmmPlotReport(
    ...     attributes=ALL_ATTRIBUTES),filter_algorithms=ALL_FMM_ALGOS)

    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``PlanningReport``."""
        PlanningReport.__init__(self, **kwargs)

    def _translate_mp(self, mp):
        """
        Turn a textual description into a pair of 0/1 flags ``(h, s)``.

        ``mp`` is a string of ``', '``-separated fields of the form
        ``'f_g: <int>'``, ``'b_g: <int>'``, ``'f_h: <int>'`` and
        ``'b_h: <int>'``. Any string containing ``'-'`` yields ``(0, 0)``.

        ``h`` is 1 when ``b_h >= f_h`` and ``s`` is 1 when ``b_g >= f_g``;
        otherwise the respective flag is 0.

        Raises ``ValueError`` if one of the four fields is missing or its
        value is not an integer.
        """
        if '-' in mp:
            return (0, 0)

        wanted = ('f_g', 'b_g', 'f_h', 'b_h')
        values = {}
        for part in mp.split(', '):
            # Split on the separator instead of relying on a fixed offset so
            # that stray whitespace around the key does not corrupt the value.
            key, sep, raw = part.partition(': ')
            key = key.strip()
            if sep and key in wanted:
                values[key] = int(raw)

        missing = [key for key in wanted if key not in values]
        if missing:
            # Fail with a descriptive error rather than an UnboundLocalError.
            raise ValueError(
                "missing field(s) {0} in {1!r}".format(', '.join(missing), mp))

        h = 1 if values['b_h'] >= values['f_h'] else 0
        s = 1 if values['b_g'] >= values['f_g'] else 0
        return (h, s)

    def _get_image_markup(self, images):
        """
        Return one ``'[<file name>]'`` string per image path, sorted by path.

        Only the part after the last ``'/'`` of each path is kept.
        """
        return ['[' + image.split('/')[-1] + ']' for image in sorted(images)]

    def _get_plots(self, attribute):
        """
        Create and save the per-problem and per-domain plots for ``attribute``.

        For every (domain, problem) pair in ``self.problem_runs`` the values of
        ``attribute`` and of ``'f_' + attribute`` are collected from each run.
        Runs whose ``attribute`` is rejected by ``valid()`` contribute ``None``
        to both series.

        Per problem, a stacked bar chart (``'f_'`` part at the bottom, the
        remainder on top) overlaid with a line of the raw values is written to
        a PNG file. Per domain, one PNG file shows a line for each problem with
        the values divided by the problem's lowest value.

        Returns the markup produced by ``_get_image_markup`` for all files
        written.
        """
        prefix = 'f_'
        nan = float('nan')

        # Gather Data
        ex = defaultdict(list)
        f_ex = defaultdict(list)
        for (domain, problem), runs in sorted(self.problem_runs.items()):
            for run in runs:
                if valid(run.get(attribute)):
                    ex[domain, problem].append(run.get(attribute))
                    f_ex[domain, problem].append(run.get(prefix + attribute))
                else:
                    ex[domain, problem].append(None)
                    f_ex[domain, problem].append(None)

        # Prepare plot: spread the algorithms evenly over [0, 1].
        # NOTE(review): this assumes every problem has exactly one run per
        # algorithm, in a consistent order -- confirm against PlanningReport.
        count = len(self.algorithms)
        if count > 1:
            fraction = 1.0 / (count - 1.0)
            t = [fraction * x for x in range(count)]
        else:
            # A single algorithm would otherwise divide by zero.
            t = [0.0] * count
        if not t:
            return self._get_image_markup([])

        plots = []
        bar_width = 0.02

        # Make Problem Plots
        for (domain, problem), expanded in sorted(ex.items()):
            if (all(x is None for x in expanded)
                    or all(x == 0 for x in expanded)):
                continue
            # Real lists (not one-shot iterators) because each series is
            # consumed several times below.
            f_ex_without_none = [
                0 if x is None else x for x in f_ex[domain, problem]]
            ex_without_none = [0 if x is None else x for x in expanded]
            remainder = [
                x1 - x2 for (x1, x2) in zip(ex_without_none, f_ex_without_none)]
            # Missing values become NaN so the line simply has a gap there.
            line = [nan if x is None else x for x in expanded]

            fig, ax = plt.subplots()
            ax.bar(t, f_ex_without_none, bar_width)
            ax.bar(t, remainder, bar_width, bottom=f_ex_without_none)
            ax.plot(t, line, 'bo--')
            ax.set(ylabel=attribute + ' (#)', xlabel='fMM (p)',
                   title='fMM Algorithm - ' + domain + ':' + problem)
            ax.set_xlim(t[0] - 0.05, t[-1] + 0.05)
            ax.grid()
            # The last five characters of the report file name and of the
            # problem name are dropped -- presumably '.html' and '.pddl';
            # TODO confirm.
            plots.append(self.outfile[0:-5] + domain + '.' + problem[0:-5]
                         + '.' + attribute + '.png')
            fig.savefig(plots[-1])
            # Release the figure; otherwise every plot stays in memory.
            plt.close(fig)

        # Make Domain Plots
        fig_sum = {}
        ax_sum = {}
        for (domain, problem), expanded in sorted(ex.items()):
            if all(x is None for x in expanded):
                continue
            if domain not in fig_sum:
                fig_sum[domain], ax_sum[domain] = plt.subplots()
            ax = ax_sum[domain]
            lowest = min(x for x in expanded if x is not None)
            # Avoid dividing by zero when the best value is 0.
            divisor = lowest if lowest != 0 else 1
            new_expanded = [
                nan if x is None else (x * 1.0) / divisor for x in expanded]
            ax.plot(t, new_expanded, 'bo--', label=problem)

        for domain in sorted(fig_sum):
            fig = fig_sum[domain]
            ax = ax_sum[domain]
            ax.set(ylabel=attribute + ' (%)', xlabel='fMM (p)',
                   title='fMM Algorithm - ' + domain)
            ax.grid()
            plots.append(self.outfile[0:-5] + domain + '.' + attribute + '.png')
            fig.savefig(plots[-1])
            plt.close(fig)

        return self._get_image_markup(plots)

    def get_markup(self):
        """
        Build the report body: plots for 'expanded' and 'jump_expanded'.

        Returns the image markup entries of both attributes joined by
        newlines. As a side effect the global matplotlib setting
        ``figure.max_open_warning`` is set to 0.
        """
        plt.rcParams.update({'figure.max_open_warning': 0})
        plots = self._get_plots('expanded')
        plots = plots + self._get_plots('jump_expanded')
        return '\n'.join(plots)


# List of properties:
# domains               : (domain), problems
# problems              : set (domain, problem)
# problem_runs          : (domain, problem), runs
# domain_algorithm_runs : (domain, algorithm), runs
# runs                  : (domain, problem, algo), run
# attributes
# algorithms            : set (algorithm)
# algorithm_info
